/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file kernel_simt_cast_impl.h
 * \brief SIMT cast overload registration macros and rounding helper implementations.
 */
#ifndef ASCENDC_MODULE_SIMT_CAST_IMPL_H
#define ASCENDC_MODULE_SIMT_CAST_IMPL_H

#include "kernel_utils.h"
#include "kernel_simt_common_impl.h"

namespace AscendC {
namespace Simt {

// Vector cast/round registration macros.
// Two alternative sets are provided with the same macro names (REG_CAST_IMPL_VEC,
// REG_CAST_IMPL_VEC_, REG_CAST_HF_IMPL_VEC, REG_CAST_HF_IMPL_VEC_):
//   - without ASCENDC_CPU_DEBUG the generated functions access elements by index (dst[i], src[i]);
//   - with ASCENDC_CPU_DEBUG the generated functions access named members (.x/.y/.z/.w).
#ifndef ASCENDC_CPU_DEBUG
// Defines `void func_name(type &dst, type &src)` which applies bisheng::cce::simt::func_name
// to each of the first `len` elements of src and stores the result in the matching element of dst.
#define REG_ROUND_VEC(type, func_name, len)                \
    __aicore__ inline void func_name(type &dst, type &src) \
    {                                                      \
        for (int i = 0; i < len; i++) {                    \
            dst[i] = bisheng::cce::simt::func_name(src[i]);\
        }                                                  \
    }

// Defines `void func_name_(dst_type &dst, src_type &src)` (note the pasted trailing underscore)
// which converts each of the first `len` elements through func_name_<d_type, s_type>.
// d_type/s_type are the scalar element types passed as explicit template arguments.
#define REG_ROUND_VEC_(dst_type, src_type, d_type, s_type, func_name, len) \
    __aicore__ inline void func_name##_(dst_type &dst, src_type &src)      \
    {                                                                      \
        for (int i = 0; i < len; i++) {                                    \
            dst[i] = func_name##_<d_type, s_type>(src[i]);                 \
        }                                                                  \
    }

// Registers func_name for the vector types type1, type2, type3 and type4 (names built by token pasting).
#define REG_CAST_IMPL_VEC(type, func_name) \
    REG_ROUND_VEC(type##1, func_name, 1)   \
    REG_ROUND_VEC(type##2, func_name, 2)   \
    REG_ROUND_VEC(type##3, func_name, 3)   \
    REG_ROUND_VEC(type##4, func_name, 4)

// Registers func_name_ for the 1- to 4-element vector pairs (dst_typeN, src_typeN).
#define REG_CAST_IMPL_VEC_(dst_type, src_type, func_name)                      \
    REG_ROUND_VEC_(dst_type##1, src_type##1, dst_type, src_type, func_name, 1) \
    REG_ROUND_VEC_(dst_type##2, src_type##2, dst_type, src_type, func_name, 2) \
    REG_ROUND_VEC_(dst_type##3, src_type##3, dst_type, src_type, func_name, 3) \
    REG_ROUND_VEC_(dst_type##4, src_type##4, dst_type, src_type, func_name, 4)

// "HF" variant: registers func_name only for the 2-element vector type (type2).
#define REG_CAST_HF_IMPL_VEC(type, func_name) REG_ROUND_VEC(type##2, func_name, 2)

// "HF" variant: registers func_name_ only for the 2-element pair (dst_type2, src_type2).
#define REG_CAST_HF_IMPL_VEC_(dst_type, src_type, func_name) \
    REG_ROUND_VEC_(dst_type##2, src_type##2, dst_type, src_type, func_name, 2)
#else
// CPU-debug variants: one macro per vector width, using member access instead of indexing.
// The per-element call is the unqualified func_name (no bisheng::cce::simt:: prefix here).

// 1-element vector: converts .x only.
#define REG_ROUND_VEC_1(type, func_name)                   \
    __aicore__ inline void func_name(type &dst, type &src) \
    {                                                      \
        dst.x = func_name(src.x);                          \
    }

// 2-element vector: converts .x and .y.
#define REG_ROUND_VEC_2(type, func_name)                   \
    __aicore__ inline void func_name(type &dst, type &src) \
    {                                                      \
        dst.x = func_name(src.x);                          \
        dst.y = func_name(src.y);                          \
    }

// 3-element vector: converts .x, .y and .z.
#define REG_ROUND_VEC_3(type, func_name)                   \
    __aicore__ inline void func_name(type &dst, type &src) \
    {                                                      \
        dst.x = func_name(src.x);                          \
        dst.y = func_name(src.y);                          \
        dst.z = func_name(src.z);                          \
    }

// 4-element vector: converts .x, .y, .z and .w.
#define REG_ROUND_VEC_4(type, func_name)                   \
    __aicore__ inline void func_name(type &dst, type &src) \
    {                                                      \
        dst.x = func_name(src.x);                          \
        dst.y = func_name(src.y);                          \
        dst.z = func_name(src.z);                          \
        dst.w = func_name(src.w);                          \
    }

// Type-converting 1-element variant: dst.x = func_name_<d_type, s_type>(src.x).
#define REG_ROUND_VEC_1_(dst_type, src_type, d_type, s_type, func_name) \
    __aicore__ inline void func_name##_(dst_type &dst, src_type &src)   \
    {                                                                   \
        dst.x = func_name##_<d_type, s_type>(src.x);                    \
    }

// Type-converting 2-element variant (.x, .y).
#define REG_ROUND_VEC_2_(dst_type, src_type, d_type, s_type, func_name) \
    __aicore__ inline void func_name##_(dst_type &dst, src_type &src)   \
    {                                                                   \
        dst.x = func_name##_<d_type, s_type>(src.x);                    \
        dst.y = func_name##_<d_type, s_type>(src.y);                    \
    }

// Type-converting 3-element variant (.x, .y, .z).
#define REG_ROUND_VEC_3_(dst_type, src_type, d_type, s_type, func_name) \
    __aicore__ inline void func_name##_(dst_type &dst, src_type &src)   \
    {                                                                   \
        dst.x = func_name##_<d_type, s_type>(src.x);                    \
        dst.y = func_name##_<d_type, s_type>(src.y);                    \
        dst.z = func_name##_<d_type, s_type>(src.z);                    \
    }

// Type-converting 4-element variant (.x, .y, .z, .w).
#define REG_ROUND_VEC_4_(dst_type, src_type, d_type, s_type, func_name) \
    __aicore__ inline void func_name##_(dst_type &dst, src_type &src)   \
    {                                                                   \
        dst.x = func_name##_<d_type, s_type>(src.x);                    \
        dst.y = func_name##_<d_type, s_type>(src.y);                    \
        dst.z = func_name##_<d_type, s_type>(src.z);                    \
        dst.w = func_name##_<d_type, s_type>(src.w);                    \
    }

// Registers func_name for the vector types type1, type2, type3 and type4.
#define REG_CAST_IMPL_VEC(type, func_name) \
    REG_ROUND_VEC_1(type##1, func_name)    \
    REG_ROUND_VEC_2(type##2, func_name)    \
    REG_ROUND_VEC_3(type##3, func_name)    \
    REG_ROUND_VEC_4(type##4, func_name)

// Registers func_name_ for the 1- to 4-element vector pairs (dst_typeN, src_typeN).
#define REG_CAST_IMPL_VEC_(dst_type, src_type, func_name)                     \
    REG_ROUND_VEC_1_(dst_type##1, src_type##1, dst_type, src_type, func_name) \
    REG_ROUND_VEC_2_(dst_type##2, src_type##2, dst_type, src_type, func_name) \
    REG_ROUND_VEC_3_(dst_type##3, src_type##3, dst_type, src_type, func_name) \
    REG_ROUND_VEC_4_(dst_type##4, src_type##4, dst_type, src_type, func_name)

// "HF" variant: registers func_name only for the 2-element vector type (type2).
#define REG_CAST_HF_IMPL_VEC(type, func_name) REG_ROUND_VEC_2(type##2, func_name)

// "HF" variant: registers func_name_ only for the 2-element pair (dst_type2, src_type2).
#define REG_CAST_HF_IMPL_VEC_(dst_type, src_type, func_name) \
    REG_ROUND_VEC_2_(dst_type##2, src_type##2, dst_type, src_type, func_name)
#endif

// Defines a scalar `void func_name(type &dst, type &src)` that assigns
// bisheng::cce::simt::func_name(src) to dst.
// NOTE(review): this macro is defined for both build modes but references bisheng::cce::simt
// unconditionally; it is not expanded anywhere in the visible part of this file.
#define REG_ROUND(type, func_name)                         \
    __aicore__ inline void func_name(type &dst, type &src) \
    {                                                      \
        dst = bisheng::cce::simt::func_name(src);          \
    }

// Defines a scalar `void func_name_(d_type &dst, s_type &src)` (trailing underscore pasted on)
// that assigns func_name_<d_type, s_type>(src) to dst.
#define REG_ROUND_(d_type, s_type, func_name)                     \
    __aicore__ inline void func_name##_(d_type &dst, s_type &src) \
    {                                                             \
        dst = func_name##_<d_type, s_type>(src);                  \
    }

// Alias of REG_ROUND_ with the same argument order (destination type, source type, function name).
#define REG_CAST_IMPL_(dst_type, src_type, func_name) REG_ROUND_(dst_type, src_type, func_name)

// Registers the same-type vector overloads of round_mode for float1..float4 and half2.
// Not expanded anywhere in the visible part of this file.
#define REG_CAST_VEC(round_mode)         \
    REG_CAST_IMPL_VEC(float, round_mode) \
    REG_CAST_HF_IMPL_VEC(half, round_mode)

// Registers the scalar round_mode_ overloads for these (dst <- src) pairs:
// half<-float, int<-float, long<-float, bhalf<-float, float<-half, float<-int, float<-long.
#define REG_CAST_(round_mode)                \
    REG_CAST_IMPL_(half, float, round_mode)  \
    REG_CAST_IMPL_(int, float, round_mode)   \
    REG_CAST_IMPL_(long, float, round_mode)  \
    REG_CAST_IMPL_(bhalf, float, round_mode) \
    REG_CAST_IMPL_(float, half, round_mode)  \
    REG_CAST_IMPL_(float, int, round_mode)   \
    REG_CAST_IMPL_(float, long, round_mode)

// Registers the vector round_mode_ overloads: 1- to 4-element int<-float, long<-float,
// float<-int and float<-long, plus the 2-element float2<-half2 conversion.
#define REG_CAST_VEC_(round_mode)                   \
    REG_CAST_IMPL_VEC_(int, float, round_mode)      \
    REG_CAST_IMPL_VEC_(long, float, round_mode)     \
    REG_CAST_IMPL_VEC_(float, int, round_mode)      \
    REG_CAST_IMPL_VEC_(float, long, round_mode)     \
    REG_CAST_HF_IMPL_VEC_(float, half, round_mode)

// Only when ASCENDC_CPU_DEBUG is defined: additionally register the 2-element
// half2 <- float2 overloads (Rint_, Floor_, Ceil_, Trunc_, CastNone_).
// These pairs are not part of REG_CAST_VEC_, so other builds do not get them from this file.
#ifdef ASCENDC_CPU_DEBUG
REG_CAST_HF_IMPL_VEC_(half, float, Rint)
REG_CAST_HF_IMPL_VEC_(half, float, Floor)
REG_CAST_HF_IMPL_VEC_(half, float, Ceil)
REG_CAST_HF_IMPL_VEC_(half, float, Trunc)
REG_CAST_HF_IMPL_VEC_(half, float, CastNone)
#endif

// Scalar two-argument overloads Rint_/Floor_/Ceil_/Trunc_/CastNone_(dst, src) for the
// type pairs listed in REG_CAST_.
REG_CAST_(Rint)
REG_CAST_(Floor)
REG_CAST_(Ceil)
REG_CAST_(Trunc)
REG_CAST_(CastNone)

// Scalar float <- bfloat16_t overloads for every rounding function.
// NOTE(review): REG_CAST_ spells the reverse direction with `bhalf` while these use
// `bfloat16_t`; presumably both names are declared by an included header - confirm.
REG_CAST_IMPL_(float, bfloat16_t, CastNone)
REG_CAST_IMPL_(float, bfloat16_t, Ceil)
REG_CAST_IMPL_(float, bfloat16_t, Floor)
REG_CAST_IMPL_(float, bfloat16_t, Trunc)
REG_CAST_IMPL_(float, bfloat16_t, Rint)

// Vector two-argument overloads for the type pairs listed in REG_CAST_VEC_.
REG_CAST_VEC_(Rint)
REG_CAST_VEC_(Floor)
REG_CAST_VEC_(Ceil)
REG_CAST_VEC_(Trunc)
REG_CAST_VEC_(CastNone)

/*!
 * \brief Convert x from T2 to T1 using the rounding mode selected at compile time.
 *
 * Dispatches to the two-argument overloads Rint_, Trunc_, Floor_, Ceil_ or CastNone_
 * (dst, src) according to roundMode. CAST_EVEN and CAST_ZERO are only handled when
 * __NPU_ARCH__ is 3101 or 5102; on any other architecture they reach the default branch.
 *
 * \tparam T1 destination type
 * \tparam T2 source type
 * \tparam roundMode rounding mode used for the conversion
 * \param x value to convert
 * \return the converted value; a value-initialized T1 if roundMode is not handled
 */
template <typename T1, typename T2, RoundMode roundMode>
__aicore__ inline T1 CastImpl(T2 x)
{
    // Value-initialize the result so that the default branch (where nothing is written to y)
    // never returns an indeterminate value.
    T1 y{};
    switch (roundMode) {
#if (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102)
        case RoundMode::CAST_EVEN:
            Rint_(y, x);
            break;
        case RoundMode::CAST_ZERO:
            Trunc_(y, x);
            break;
#endif
        case RoundMode::CAST_FLOOR:
            Floor_(y, x);
            break;
        case RoundMode::CAST_CEIL:
            Ceil_(y, x);
            break;
        case RoundMode::CAST_NONE:
            CastNone_(y, x);
            break;
        default:
            // Unsupported mode for this build: report it and fall through to return y as initialized.
            ASSERT(false && "Cast: An invalid RoundMode!");
            break;
    }
    return y;
}

/*!
 * \brief Forward x to RoundIntrinsicsImpl and return its result.
 * \param x input value
 * \return the value produced by RoundIntrinsicsImpl(x)
 */
template <typename T>
__aicore__ inline T RoundImpl(T x)
{
    T rounded = RoundIntrinsicsImpl(x);
    return rounded;
}

/*!
 * \brief Forward x to RintIntrinsicsImpl and return its result.
 * \param x input value
 * \return the value produced by RintIntrinsicsImpl(x)
 */
template <typename T>
__aicore__ inline T RintImpl(T x)
{
    T rounded = RintIntrinsicsImpl(x);
    return rounded;
}

/*!
 * \brief Forward x to FloorIntrinsicsImpl and return its result.
 * \param x input value
 * \return the value produced by FloorIntrinsicsImpl(x)
 */
template <typename T>
__aicore__ inline T FloorImpl(T x)
{
    T floored = FloorIntrinsicsImpl(x);
    return floored;
}

/*!
 * \brief Forward x to CeilIntrinsicsImpl and return its result.
 * \param x input value
 * \return the value produced by CeilIntrinsicsImpl(x)
 */
template <typename T>
__aicore__ inline T CeilImpl(T x)
{
    T ceiled = CeilIntrinsicsImpl(x);
    return ceiled;
}

/*!
 * \brief Select between FloorImpl and CeilImpl based on the sign of x.
 *
 * Values strictly greater than zero go through FloorImpl; every other value
 * (including zero and any value for which the comparison is false) goes through CeilImpl.
 *
 * \param x input value
 * \return FloorImpl(x) when x > 0, otherwise CeilImpl(x)
 */
template <typename T>
__aicore__ inline T TruncImpl(T x)
{
    const bool isPositive = x > static_cast<T>(0);
    return isPositive ? FloorImpl(x) : CeilImpl(x);
}

}  // namespace Simt
}  // namespace AscendC
#endif  // ASCENDC_MODULE_SIMT_CAST_IMPL_H
